{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "import math\n",
    "\n",
    "FILE = \"fashion-mnist_train.csv\"\n",
    "with open(f\"../../data/{FILE}\") as f:\n",
    "    examples = f.read().strip().split(\"\\n\")[1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [],
   "source": [
    "random.seed(0)\n",
    "train, valid, test = [], [], []\n",
    "valid_cutoff = .8\n",
    "test_cutoff = .9\n",
    "\n",
    "for example in examples:\n",
    "    label, pixels = example.split(\",\", 1)\n",
    "    label = int(label)\n",
    "    if label > 4:\n",
    "        continue\n",
    "    pixels = [int(p) for p in pixels.split(\",\")]\n",
    "    data = np.asarray(pixels, dtype=\"int32\")\n",
    "    data = data.reshape((1,28,28))\n",
    "    # Scale values from -1 to 1\n",
    "    data = ((data / 255) - .5) / .5\n",
    "    # Move channel axis to first axis\n",
    "    #data = np.moveaxis(data, -1, 0)\n",
    "    target = np.zeros((1, 5))\n",
    "    target[0,label] = 1\n",
    "    row = (data, target, )\n",
    "\n",
    "    split = random.random()\n",
    "    if split > test_cutoff:\n",
    "        test.append(row)\n",
    "    elif split > valid_cutoff:\n",
    "        valid.append(row)\n",
    "    else:\n",
    "        train.append(row)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [],
   "source": [
    "def log_loss(predicted, actual):\n",
    "    tol = 1e-6\n",
    "    cross_entropy = actual * np.log(predicted + tol)\n",
    "    return -np.sum(cross_entropy)\n",
    "\n",
    "def log_loss_grad(predicted, actual):\n",
    "    return predicted - actual\n",
    "\n",
    "def softmax(preds):\n",
    "    tol = 1e-6\n",
    "    preds = np.exp(preds - np.max(preds))\n",
    "    return preds / (np.sum(preds) - tol)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [],
   "source": [
    "def init_layers(layer_defs):\n",
    "    layers = []\n",
    "    for i in range(1, len(layer_defs)):\n",
    "        if \"input_units\" in layer_defs[i]:\n",
    "            last_units = layer_defs[i][\"input_units\"]\n",
    "        else:\n",
    "            last_units = layer_defs[i-1][\"units\"]\n",
    "\n",
    "        biases = np.ones((1,layer_defs[i][\"units\"]))\n",
    "        if layer_defs[i][\"type\"] == \"cnn\":\n",
    "            np.random.seed(0)\n",
    "            weights = np.random.rand(layer_defs[i-1][\"units\"], layer_defs[i][\"units\"], layer_defs[i][\"kernel_size\"], layer_defs[i][\"kernel_size\"])\n",
    "            weights = weights / 10\n",
    "        else:\n",
    "            np.random.seed(0)\n",
    "            weights = np.random.rand(last_units, layer_defs[i][\"units\"])\n",
    "            weights = weights / 5 - .1\n",
    "\n",
    "        layers.append([\n",
    "            weights,\n",
    "            biases,\n",
    "            layer_defs[i][\"type\"]\n",
    "        ])\n",
    "    return layers"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [],
   "source": [
    "from skimage.util import view_as_windows\n",
    "\n",
    "def unroll_image_manual(image, kernel_x, kernel_y):\n",
    "    x_size = (image.shape[0] - (kernel_x - 1))\n",
    "    y_size = (image.shape[1] - (kernel_y - 1))\n",
    "    rows =  x_size * y_size\n",
    "    unrolled = np.zeros((rows, kernel_x * kernel_y))\n",
    "    for x in range(0, x_size):\n",
    "        for y in range(0, y_size):\n",
    "            unrolled[y + (x * y_size),:] = image[x:(x+kernel_x),y:(y+kernel_y)].reshape((1,kernel_x * kernel_y))\n",
    "    return unrolled\n",
    "\n",
    "def unroll_image(image, kernel_x, kernel_y):\n",
    "    unrolled = view_as_windows(image, (kernel_x, kernel_y))\n",
    "    x_size = (image.shape[0] - (kernel_x - 1))\n",
    "    y_size = (image.shape[1] - (kernel_y - 1))\n",
    "    rows =  x_size * y_size\n",
    "    return unrolled.reshape(rows, kernel_x * kernel_y)\n",
    "\n",
    "def convolve(image, kernel):\n",
    "    return np.matmul(image, kernel.reshape(math.prod(kernel.shape), 1))"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "outputs": [],
   "source": [
    "def forward(batch, layers):\n",
    "    hidden = [batch.copy()]\n",
    "    for i in range(len(layers)):\n",
    "        if layers[i][2] == \"cnn\":\n",
    "            channels, next_channels, kernel_x, kernel_y = layers[i][0].shape\n",
    "\n",
    "            new_x = batch.shape[1] - (kernel_x - 1)\n",
    "            new_y = batch.shape[2] - (kernel_y - 1)\n",
    "            next_batch = np.zeros((next_channels, new_x , new_y))\n",
    "            for channel in range(channels):\n",
    "                unrolled = unroll_image(batch[channel,:], kernel_x, kernel_y)\n",
    "                for next_channel in range(next_channels):\n",
    "                    kernel = layers[i][0][channel, next_channel, :]\n",
    "                    mult = convolve(unrolled, kernel).reshape(new_x, new_y)\n",
    "                    next_batch[next_channel,:] += mult\n",
    "            next_batch /= batch.shape[0]\n",
    "\n",
    "            hidden.append(next_batch.copy())\n",
    "            next_batch = np.maximum(next_batch, 0)\n",
    "            batch = next_batch\n",
    "        else:\n",
    "            if layers[i-1][2] == \"cnn\" or i == 0:\n",
    "                batch = batch.reshape(1, math.prod(batch.shape))\n",
    "            batch = np.matmul(batch, layers[i][0]) + layers[i][1]\n",
    "            hidden.append(batch.copy())\n",
    "            if i < len(layers) - 1:\n",
    "                batch = np.maximum(batch, 0)\n",
    "\n",
    "    return batch, hidden"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "outputs": [],
   "source": [
    "def backward(layers, hidden, grad, lr, verbose=False):\n",
    "    for i in range(len(layers)-1, -1, -1):\n",
    "        print(f\"Layer {i}\") if verbose else None\n",
    "\n",
    "        if layers[i][2] == \"cnn\":\n",
    "            grad = grad.reshape(hidden[i+1].shape)\n",
    "            if i != len(layers) - 1:\n",
    "                grad = np.multiply(grad, np.heaviside(hidden[i+1], 0))\n",
    "            _, kernel_x, kernel_y = grad.shape\n",
    "            print(f\"Grad shape {grad.shape}\") if verbose else None\n",
    "\n",
    "            new_grad = np.zeros(hidden[i].shape)\n",
    "            for channel in range(hidden[i].shape[0]):\n",
    "                # With multi-channel output, you need to loop across the output grads to link to input channel kernels\n",
    "                # Each kernel gets a unique update\n",
    "                flat_input = unroll_image(hidden[i][channel,:], kernel_x, kernel_y)\n",
    "                for next_channel in range(hidden[i+1].shape[0]):\n",
    "                    # Kernel update\n",
    "                    channel_grad = grad[next_channel,:]\n",
    "                    k_grad = convolve(flat_input, channel_grad).reshape(layers[i][0].shape[2], layers[i][0].shape[3])\n",
    "                    grad_norm = math.prod(channel_grad.shape)\n",
    "                    layers[i][0][channel,next_channel,:] -= (k_grad * lr) / grad_norm\n",
    "                    print(f\"k_grad: {k_grad.shape}\") if verbose else None\n",
    "            for next_channel in range(hidden[i+1].shape[0]):\n",
    "                channel_grad = grad[next_channel,:]\n",
    "                kernel_x = (layers[i][0][0,next_channel,:].shape[0])\n",
    "                kernel_y = (layers[i][0][0,next_channel,:].shape[1])\n",
    "                padded_grad = np.pad(channel_grad, ((kernel_x - 1, kernel_x - 1),  (kernel_y - 1, kernel_y - 1)))\n",
    "                flat_padded = unroll_image(padded_grad, kernel_x, kernel_y)\n",
    "                for channel in range(hidden[i].shape[0]):\n",
    "                    # Grad to lower layer\n",
    "                    flipped_kernel = np.flip(layers[i][0][channel,next_channel,:], axis=[0,1])\n",
    "                    updated_grad = convolve(flat_padded, flipped_kernel).reshape(hidden[i].shape[1], hidden[i].shape[2])\n",
    "                    # Since we're multiplying each input by multiple kernel values, reduce the gradient accordingly\n",
    "                    # This will reduce the edges more than necessary (they contribute to fewer output values), but is simple to implement\n",
    "                    new_grad[channel, :] += updated_grad / math.prod(flipped_kernel.shape)\n",
    "            grad = new_grad\n",
    "        else:\n",
    "            if i != len(layers) - 1:\n",
    "                grad = np.multiply(grad, np.heaviside(hidden[i+1], 0))\n",
    "            grad = grad.T\n",
    "            print(f\"starting grad: {grad.shape}\") if verbose else None\n",
    "            w_grad = np.matmul(grad, hidden[i].reshape(1, math.prod(hidden[i].shape))).T\n",
    "            print(f\"w_grad: {w_grad.shape}\") if verbose else None\n",
    "            b_grad = grad.T\n",
    "\n",
    "            layers[i][0] -= w_grad * lr\n",
    "            layers[i][1] -= b_grad * lr\n",
    "\n",
    "            grad = np.matmul(layers[i][0], grad).T\n",
    "            print(f\"ending grad: {grad.shape}\") if verbose else None\n",
    "    return layers"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 loss: 0.7330912743388215\n",
      "Epoch 0 valid accuracy: 0.7805611222444889\n",
      "Epoch 1 loss: 0.596589137086033\n",
      "Epoch 1 valid accuracy: 0.8072812291249165\n",
      "Epoch 2 loss: 0.5335671947706974\n",
      "Epoch 2 valid accuracy: 0.8183032732130928\n",
      "Epoch 3 loss: 0.5067247472491165\n",
      "Epoch 3 valid accuracy: 0.8299933199732799\n",
      "Epoch 4 loss: 0.4908787078802125\n",
      "Epoch 4 valid accuracy: 0.8350033400133601\n"
     ]
    }
   ],
   "source": [
    "layer_defs = [\n",
    "    {\"type\": \"input\", \"units\": 1},\n",
    "    {\"type\": \"cnn\", \"input_units\": 1, \"units\": 1, \"kernel_size\": 3},\n",
    "    {\"type\": \"dense\", \"input_units\": 26 ** 2, \"units\": 5}\n",
    "]\n",
    "lr = 5e-3\n",
    "epochs = 5\n",
    "\n",
    "layers = init_layers(layer_defs)\n",
    "for epoch in range(epochs):\n",
    "    epoch_loss = 0\n",
    "    for i, img in enumerate(train):\n",
    "        image, target = img\n",
    "        batch, hidden = forward(image, layers)\n",
    "\n",
    "        grad = log_loss_grad(softmax(batch), target)\n",
    "        epoch_loss += log_loss(softmax(batch), target)\n",
    "        layers = backward(layers, hidden, grad, lr)\n",
    "\n",
    "    print(f\"Epoch {epoch} loss: {epoch_loss / len(train)}\")\n",
    "    match = np.zeros(len(valid))\n",
    "    for i, img in enumerate(valid):\n",
    "        image, target = img\n",
    "        valid_pred, _ = forward(image, layers)\n",
    "        match[i] = np.argmax(valid_pred) == np.argmax(target)\n",
    "\n",
    "    print(f\"Epoch {epoch} valid accuracy: {match.sum() / match.shape[0]}\")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "outputs": [
    {
     "data": {
      "text/plain": "0.8531147540983607"
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_preds = []\n",
    "match = np.zeros(len(test))\n",
    "for i, img in enumerate(test):\n",
    "    image, target = img\n",
    "    test_pred, _ = forward(image, layers)\n",
    "    test_preds.append(softmax(test_pred))\n",
    "    match[i] = np.argmax(test_pred) == np.argmax(target)\n",
    "\n",
    "match.sum() / match.shape[0]"
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
